# _*_ coding: utf-8 _*_
# @Date : 2023/3/24 16:36
# @Author : Paul
# @File : naive_bayes.py
# @Description : Gaussian naive Bayes (sklearn GaussianNB) classifier job

from classifiers.classifier import Classifier
from core.utils.log_util import LogUtil
from core.beans.classifier_result import ClassifierResult
from sklearn.naive_bayes import GaussianNB
from core.utils.date_util import DateUtil
from core.utils.string_utils import StringUtils
import json
import sys

class TWSNavieBayes(Classifier):
    """Gaussian naive Bayes classifier plugged into the project's ``Classifier`` workflow.

    Wraps :class:`sklearn.naive_bayes.GaussianNB`. Hyper-parameters are read from
    ``param["algoParam"]`` (keys ``priors`` and ``varSmoothing``), and the
    test-set score is persisted through ``LogUtil.saveClassifierResult``.
    """

    def __init__(self,
                 app_name="svm",
                 data_source_id=None,
                 table_name=None,
                 feature_cols=None,
                 class_col=None,
                 train_size=None,
                 param=None
                 ):
        """
        Initialize the classifier.

        :param app_name: application name; also embedded in ``tree_pred_image``.
            NOTE(review): the "svm" default looks copied from another classifier;
            it is kept unchanged for backward compatibility.
        :param data_source_id: data source id, forwarded to ``Classifier``.
        :param table_name: source table name, forwarded to ``Classifier``.
        :param feature_cols: feature column(s), forwarded to ``Classifier``.
        :param class_col: label column(s), forwarded to ``Classifier``.
        :param train_size: training-set ratio, forwarded to ``Classifier``.
        :param param: full job parameter dict; ``initModel`` reads
            ``param["algoParam"]`` and ``evalModel`` reads ``param["id"]``.
        """
        super(TWSNavieBayes, self).__init__(app_name=app_name,
                                             data_source_id=data_source_id,
                                             table_name=table_name,
                                             feature_cols=feature_cols,
                                             class_col=class_col,
                                             train_size=train_size,
                                             param=param)
        # Model evaluation is switched ON for this classifier.
        # NOTE(review): presumably read by the base Classifier to decide whether
        # evalModel runs — confirm.
        self.IS_MODEL_EVAL = True
        # Name of the prediction-distribution image. The "descion_tree_pred_"
        # prefix appears copied from the decision-tree classifier; it is kept
        # verbatim because other code may depend on this exact name.
        self.tree_pred_image = "descion_tree_pred_" + app_name + "_" + DateUtil.getCurrentDateSimple()

    @staticmethod
    def _parsePriors(raw):
        """
        Convert the raw ``priors`` setting into a list of floats, or None.

        GaussianNB expects ``priors`` to be array-like with one probability
        per class; a plain string is rejected. Accepted inputs:

        - None, "" or whitespace: returns None, so the priors are estimated from the data.
        - a list/tuple of numbers: each element is converted with ``float``.
        - a string such as "0.3,0.7" or "[0.3, 0.7]": split on commas.
        - a single number: returns a one-element list.

        :param raw: value of ``algoParam["priors"]`` (may be missing/None).
        :return: list of floats, or None when no priors were given.
        :raises ValueError: if a token cannot be converted to float.
        """
        if raw is None:
            return None
        if isinstance(raw, (list, tuple)):
            values = [float(v) for v in raw]
        else:
            # Tolerate JSON-style brackets around a comma-separated list.
            text = str(raw).strip().strip("[]")
            values = [float(tok) for tok in text.split(",") if tok.strip()]
        return values or None

    def initModel(self):
        """
        Create the GaussianNB estimator from ``self.param["algoParam"]``.

        ``priors`` is parsed by ``_parsePriors``; a blank value leaves the
        class priors to be estimated from the training data. ``varSmoothing``
        falls back to 1e-9 when missing or blank, otherwise it is converted
        with ``float``.
        """
        algoParam = self.param["algoParam"]
        priors = self._parsePriors(algoParam.get("priors"))

        raw_smoothing = algoParam.get("varSmoothing")
        # str() so that numeric JSON values can also go through the blank check.
        if raw_smoothing is None or StringUtils.isBlack(str(raw_smoothing)):
            var_smoothing = 1e-9
        else:
            var_smoothing = float(raw_smoothing)

        self.clf = GaussianNB(priors=priors,
                              var_smoothing=var_smoothing)

    def buildModel(self, train_data):
        """
        Fit the estimator created by ``initModel``.

        :param train_data: pair ``(X_train, y_train)``; index 0 holds the
            features and index 1 the labels.
        """
        Xtrain = train_data[0]
        Ytrain = train_data[1]
        self.clf = self.clf.fit(Xtrain, Ytrain)

    def evalModel(self, train_data, test_data):
        """
        Score the fitted model on the test split and persist the result.

        :param train_data: unused here; kept for the ``Classifier`` interface.
        :param test_data: pair ``(X_test, y_test)``.
        """
        Xtest = test_data[0]
        Ytest = test_data[1]
        # For classifiers, sklearn's score() is the mean accuracy.
        score_ = self.clf.score(Xtest, Ytest)

        # Finish time of the job.
        end_time = DateUtil.getCurrentDate()
        # NOTE(review): named "cost_second" but computed with diffMin — check the
        # unit DateUtil.diffMin actually returns.
        cost_second = DateUtil.diffMin(self.start_time, end_time)

        # Persist the model result (the original comment said: store in MySQL).
        # NOTE(review): the "sucess" status literal is misspelled but kept
        # verbatim because downstream consumers may match on it.
        algo_result = ClassifierResult(self.param["id"],
                                       "naive_bayes",
                                       self.param,
                                       self.app_name,
                                       self.info,
                                       self.describe,
                                       None,
                                       None,
                                       score_,
                                       "sucess",
                                       self.start_time,
                                       end_time,
                                       cost_second)
        LogUtil.saveClassifierResult(self.meta_data_source, algo_result)


if __name__ == '__main__':
    # The job is configured by a single JSON command-line argument, e.g.:
    #   {"algoParam": {"priors": "", "varSmoothing": ""}, "appName": "kmeans_1",
    #    "classCols": "Survived", "dataSourceId": "9",
    #    "featureCols": "Pclass,Sex,Age,SibSp,Parch,Fare,Embarked",
    #    "id": "1679648759018",
    #    "preProcessMethodList": [{"preProcessFeature": "Age",
    #                              "preProcessMethod": "fillna",
    #                              "preProcessMethodValue": "mean"}, ...],
    #    "tableName": "titanic", "trainDataRatio": "0.7"}
    if len(sys.argv) < 2:
        # Fail with a readable message instead of an IndexError traceback.
        sys.exit("usage: python naive_bayes.py '<json job parameters>'")
    param = json.loads(sys.argv[1])

    app_name = param["appName"]
    data_source_id = param["dataSourceId"]
    table_name = param["tableName"]
    feature_cols = param["featureCols"]
    class_cols = param["classCols"]
    train_size = float(param["trainDataRatio"])

    # The constructor is always given a list of label columns: a single
    # column name is wrapped in a one-element list.
    class_cols_list = class_cols if isinstance(class_cols, list) else [class_cols]

    classifier = TWSNavieBayes(app_name=app_name,
                               data_source_id=data_source_id,
                               table_name=table_name,
                               feature_cols=feature_cols,
                               class_col=class_cols_list,
                               train_size=train_size,
                               param=param)
    # The presence of a "paramTrain" key selects paramTrain() over execute().
    if "paramTrain" not in param:
        classifier.execute()
    else:
        classifier.paramTrain()
